{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import Scalar, draw_graph\n",
    "from linear_model import Linear, mse"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 2.38.0 (20140413.2041)\n",
       " -->\n",
       "<!-- Title: %3 Pages: 1 -->\n",
       "<svg width=\"353pt\" height=\"517pt\"\n",
       " viewBox=\"0.00 0.00 353.38 516.80\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 512.797)\">\n",
       "<title>%3</title>\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-512.797 349.378,-512.797 349.378,4 -4,4\"/>\n",
       "<!-- 140455917317136backward -->\n",
       "<g id=\"node1\" class=\"node\"><title>140455917317136backward</title>\n",
       "<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M99.5867,-337.156C99.5867,-337.156 151.792,-337.156 151.792,-337.156 157.792,-337.156 163.792,-343.156 163.792,-349.156 163.792,-349.156 163.792,-384.078 163.792,-384.078 163.792,-390.078 157.792,-396.078 151.792,-396.078 151.792,-396.078 99.5867,-396.078 99.5867,-396.078 93.5867,-396.078 87.5867,-390.078 87.5867,-384.078 87.5867,-384.078 87.5867,-349.156 87.5867,-349.156 87.5867,-343.156 93.5867,-337.156 99.5867,-337.156\"/>\n",
       "<text text-anchor=\"middle\" x=\"125.689\" y=\"-384.078\" font-family=\"Menlo\" font-size=\"10.00\">grad=4.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"87.5867,-376.438 163.792,-376.438 \"/>\n",
       "<text text-anchor=\"middle\" x=\"125.689\" y=\"-364.438\" font-family=\"Menlo\" font-size=\"10.00\">value=4.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"87.5867,-356.797 163.792,-356.797 \"/>\n",
       "<text text-anchor=\"middle\" x=\"125.689\" y=\"-344.797\" font-family=\"Menlo\" font-size=\"10.00\">&#45;</text>\n",
       "</g>\n",
       "<!-- 140455917316464backward -->\n",
       "<g id=\"node9\" class=\"node\"><title>140455917316464backward</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"34.6892\" cy=\"-254.398\" rx=\"34.8795\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"34.6892\" y=\"-252.219\" font-family=\"Menlo\" font-size=\"10.00\">y2=4.00</text>\n",
       "</g>\n",
       "<!-- 140455917317136backward&#45;&gt;140455917316464backward -->\n",
       "<g id=\"edge7\" class=\"edge\"><title>140455917317136backward&#45;&gt;140455917316464backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" stroke-dasharray=\"5,2\" d=\"M102.018,-336.946C84.7445,-316.025 61.971,-288.442 47.8333,-271.319\"/>\n",
       "</g>\n",
       "<!-- 140455917317088backward -->\n",
       "<g id=\"node13\" class=\"node\"><title>140455917317088backward</title>\n",
       "<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M99.5867,-224.938C99.5867,-224.938 151.792,-224.938 151.792,-224.938 157.792,-224.938 163.792,-230.938 163.792,-236.938 163.792,-236.938 163.792,-271.859 163.792,-271.859 163.792,-277.859 157.792,-283.859 151.792,-283.859 151.792,-283.859 99.5867,-283.859 99.5867,-283.859 93.5867,-283.859 87.5867,-277.859 87.5867,-271.859 87.5867,-271.859 87.5867,-236.938 87.5867,-236.938 87.5867,-230.938 93.5867,-224.938 99.5867,-224.938\"/>\n",
       "<text text-anchor=\"middle\" x=\"125.689\" y=\"-271.859\" font-family=\"Menlo\" font-size=\"10.00\">grad=&#45;4.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"87.5867,-264.219 163.792,-264.219 \"/>\n",
       "<text text-anchor=\"middle\" x=\"125.689\" y=\"-252.219\" font-family=\"Menlo\" font-size=\"10.00\">value=0.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"87.5867,-244.578 163.792,-244.578 \"/>\n",
       "<text text-anchor=\"middle\" x=\"125.689\" y=\"-232.578\" font-family=\"Menlo\" font-size=\"10.00\">+</text>\n",
       "</g>\n",
       "<!-- 140455917317136backward&#45;&gt;140455917317088backward -->\n",
       "<g id=\"edge14\" class=\"edge\"><title>140455917317136backward&#45;&gt;140455917317088backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" d=\"M125.689,-336.946C125.689,-323.934 125.689,-308.345 125.689,-294.299\"/>\n",
       "<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"129.189,-293.926 125.689,-283.926 122.189,-293.926 129.189,-293.926\"/>\n",
       "<text text-anchor=\"middle\" x=\"146.761\" y=\"-307.456\" font-family=\"Menlo\" font-size=\"14.00\">&#45;4.00</text>\n",
       "</g>\n",
       "<!-- 140455917316656backward -->\n",
       "<g id=\"node2\" class=\"node\"><title>140455917316656backward</title>\n",
       "<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M243.587,-112.719C243.587,-112.719 295.792,-112.719 295.792,-112.719 301.792,-112.719 307.792,-118.719 307.792,-124.719 307.792,-124.719 307.792,-159.641 307.792,-159.641 307.792,-165.641 301.792,-171.641 295.792,-171.641 295.792,-171.641 243.587,-171.641 243.587,-171.641 237.587,-171.641 231.587,-165.641 231.587,-159.641 231.587,-159.641 231.587,-124.719 231.587,-124.719 231.587,-118.719 237.587,-112.719 243.587,-112.719\"/>\n",
       "<text text-anchor=\"middle\" x=\"269.689\" y=\"-159.641\" font-family=\"Menlo\" font-size=\"10.00\">grad=&#45;1.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"231.587,-152 307.792,-152 \"/>\n",
       "<text text-anchor=\"middle\" x=\"269.689\" y=\"-140\" font-family=\"Menlo\" font-size=\"10.00\">value=0.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"231.587,-132.359 307.792,-132.359 \"/>\n",
       "<text text-anchor=\"middle\" x=\"269.689\" y=\"-120.359\" font-family=\"Menlo\" font-size=\"10.00\">*</text>\n",
       "</g>\n",
       "<!-- 140455917316704backward -->\n",
       "<g id=\"node4\" class=\"node\"><title>140455917316704backward</title>\n",
       "<polygon fill=\"lightgreen\" stroke=\"black\" stroke-width=\"2\" points=\"136.587,-0.5 136.587,-59.4219 212.792,-59.4219 212.792,-0.5 136.587,-0.5\"/>\n",
       "<text text-anchor=\"middle\" x=\"174.689\" y=\"-47.4219\" font-family=\"Menlo\" font-size=\"10.00\">grad=&#45;9.50</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" stroke-width=\"2\" points=\"136.587,-39.7812 212.792,-39.7812 \"/>\n",
       "<text text-anchor=\"middle\" x=\"174.689\" y=\"-27.7812\" font-family=\"Menlo\" font-size=\"10.00\">value=0.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" stroke-width=\"2\" points=\"136.587,-20.1406 212.792,-20.1406 \"/>\n",
       "<text text-anchor=\"middle\" x=\"174.689\" y=\"-8.14062\" font-family=\"Menlo\" font-size=\"10.00\">a</text>\n",
       "</g>\n",
       "<!-- 140455917316656backward&#45;&gt;140455917316704backward -->\n",
       "<g id=\"edge4\" class=\"edge\"><title>140455917316656backward&#45;&gt;140455917316704backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" d=\"M240.581,-112.634C234.804,-106.671 228.879,-100.335 223.546,-94.2188 216.214,-85.811 208.669,-76.4458 201.768,-67.5536\"/>\n",
       "<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"204.5,-65.3644 195.636,-59.5612 198.946,-69.6255 204.5,-65.3644\"/>\n",
       "<text text-anchor=\"middle\" x=\"244.761\" y=\"-83.0187\" font-family=\"Menlo\" font-size=\"14.00\">&#45;1.50</text>\n",
       "</g>\n",
       "<!-- 140455917316512backward -->\n",
       "<g id=\"node10\" class=\"node\"><title>140455917316512backward</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"269.689\" cy=\"-29.9609\" rx=\"34.8795\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"269.689\" y=\"-27.7812\" font-family=\"Menlo\" font-size=\"10.00\">x1=1.50</text>\n",
       "</g>\n",
       "<!-- 140455917316656backward&#45;&gt;140455917316512backward -->\n",
       "<g id=\"edge1\" class=\"edge\"><title>140455917316656backward&#45;&gt;140455917316512backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" stroke-dasharray=\"5,2\" d=\"M269.689,-112.509C269.689,-92.0979 269.689,-65.346 269.689,-48.1524\"/>\n",
       "</g>\n",
       "<!-- 140455917317184backward -->\n",
       "<g id=\"node3\" class=\"node\"><title>140455917317184backward</title>\n",
       "<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M150.587,-449.375C150.587,-449.375 202.792,-449.375 202.792,-449.375 208.792,-449.375 214.792,-455.375 214.792,-461.375 214.792,-461.375 214.792,-496.297 214.792,-496.297 214.792,-502.297 208.792,-508.297 202.792,-508.297 202.792,-508.297 150.587,-508.297 150.587,-508.297 144.587,-508.297 138.587,-502.297 138.587,-496.297 138.587,-496.297 138.587,-461.375 138.587,-461.375 138.587,-455.375 144.587,-449.375 150.587,-449.375\"/>\n",
       "<text text-anchor=\"middle\" x=\"176.689\" y=\"-496.297\" font-family=\"Menlo\" font-size=\"10.00\">grad=1.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"138.587,-488.656 214.792,-488.656 \"/>\n",
       "<text text-anchor=\"middle\" x=\"176.689\" y=\"-476.656\" font-family=\"Menlo\" font-size=\"10.00\">value=8.50</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"138.587,-469.016 214.792,-469.016 \"/>\n",
       "<text text-anchor=\"middle\" x=\"176.689\" y=\"-457.016\" font-family=\"Menlo\" font-size=\"10.00\">mse</text>\n",
       "</g>\n",
       "<!-- 140455917317184backward&#45;&gt;140455917317136backward -->\n",
       "<g id=\"edge13\" class=\"edge\"><title>140455917317184backward&#45;&gt;140455917317136backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" d=\"M163.423,-449.165C157.225,-435.77 149.763,-419.645 143.117,-405.282\"/>\n",
       "<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"146.265,-403.75 138.889,-396.145 139.912,-406.69 146.265,-403.75\"/>\n",
       "<text text-anchor=\"middle\" x=\"171.547\" y=\"-419.675\" font-family=\"Menlo\" font-size=\"14.00\">4.00</text>\n",
       "</g>\n",
       "<!-- 140455917316752backward -->\n",
       "<g id=\"node5\" class=\"node\"><title>140455917316752backward</title>\n",
       "<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M193.587,-337.156C193.587,-337.156 245.792,-337.156 245.792,-337.156 251.792,-337.156 257.792,-343.156 257.792,-349.156 257.792,-349.156 257.792,-384.078 257.792,-384.078 257.792,-390.078 251.792,-396.078 245.792,-396.078 245.792,-396.078 193.587,-396.078 193.587,-396.078 187.587,-396.078 181.587,-390.078 181.587,-384.078 181.587,-384.078 181.587,-349.156 181.587,-349.156 181.587,-343.156 187.587,-337.156 193.587,-337.156\"/>\n",
       "<text text-anchor=\"middle\" x=\"219.689\" y=\"-384.078\" font-family=\"Menlo\" font-size=\"10.00\">grad=1.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"181.587,-376.438 257.792,-376.438 \"/>\n",
       "<text text-anchor=\"middle\" x=\"219.689\" y=\"-364.438\" font-family=\"Menlo\" font-size=\"10.00\">value=1.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"181.587,-356.797 257.792,-356.797 \"/>\n",
       "<text text-anchor=\"middle\" x=\"219.689\" y=\"-344.797\" font-family=\"Menlo\" font-size=\"10.00\">&#45;</text>\n",
       "</g>\n",
       "<!-- 140455917317184backward&#45;&gt;140455917316752backward -->\n",
       "<g id=\"edge3\" class=\"edge\"><title>140455917317184backward&#45;&gt;140455917316752backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" d=\"M187.875,-449.165C193.051,-435.898 199.272,-419.952 204.835,-405.693\"/>\n",
       "<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"208.186,-406.733 208.56,-396.145 201.665,-404.189 208.186,-406.733\"/>\n",
       "<text text-anchor=\"middle\" x=\"218.547\" y=\"-419.675\" font-family=\"Menlo\" font-size=\"14.00\">1.00</text>\n",
       "</g>\n",
       "<!-- 140455917316896backward -->\n",
       "<g id=\"node7\" class=\"node\"><title>140455917316896backward</title>\n",
       "<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M193.587,-224.938C193.587,-224.938 245.792,-224.938 245.792,-224.938 251.792,-224.938 257.792,-230.938 257.792,-236.938 257.792,-236.938 257.792,-271.859 257.792,-271.859 257.792,-277.859 251.792,-283.859 245.792,-283.859 245.792,-283.859 193.587,-283.859 193.587,-283.859 187.587,-283.859 181.587,-277.859 181.587,-271.859 181.587,-271.859 181.587,-236.938 181.587,-236.938 181.587,-230.938 187.587,-224.938 193.587,-224.938\"/>\n",
       "<text text-anchor=\"middle\" x=\"219.689\" y=\"-271.859\" font-family=\"Menlo\" font-size=\"10.00\">grad=&#45;1.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"181.587,-264.219 257.792,-264.219 \"/>\n",
       "<text text-anchor=\"middle\" x=\"219.689\" y=\"-252.219\" font-family=\"Menlo\" font-size=\"10.00\">value=0.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"181.587,-244.578 257.792,-244.578 \"/>\n",
       "<text text-anchor=\"middle\" x=\"219.689\" y=\"-232.578\" font-family=\"Menlo\" font-size=\"10.00\">+</text>\n",
       "</g>\n",
       "<!-- 140455917316752backward&#45;&gt;140455917316896backward -->\n",
       "<g id=\"edge6\" class=\"edge\"><title>140455917316752backward&#45;&gt;140455917316896backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" d=\"M219.689,-336.946C219.689,-323.934 219.689,-308.345 219.689,-294.299\"/>\n",
       "<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"223.189,-293.926 219.689,-283.926 216.189,-293.926 223.189,-293.926\"/>\n",
       "<text text-anchor=\"middle\" x=\"240.761\" y=\"-307.456\" font-family=\"Menlo\" font-size=\"14.00\">&#45;1.00</text>\n",
       "</g>\n",
       "<!-- 140455917316416backward -->\n",
       "<g id=\"node8\" class=\"node\"><title>140455917316416backward</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"310.689\" cy=\"-254.398\" rx=\"34.8795\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"310.689\" y=\"-252.219\" font-family=\"Menlo\" font-size=\"10.00\">y1=1.00</text>\n",
       "</g>\n",
       "<!-- 140455917316752backward&#45;&gt;140455917316416backward -->\n",
       "<g id=\"edge9\" class=\"edge\"><title>140455917316752backward&#45;&gt;140455917316416backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" stroke-dasharray=\"5,2\" d=\"M249.134,-336.973C254.801,-331.071 260.563,-324.783 265.689,-318.656 278.402,-303.461 291.353,-284.859 300.076,-271.772\"/>\n",
       "</g>\n",
       "<!-- 140455917316848backward -->\n",
       "<g id=\"node6\" class=\"node\"><title>140455917316848backward</title>\n",
       "<polygon fill=\"lightgreen\" stroke=\"black\" stroke-width=\"2\" points=\"136.587,-112.719 136.587,-171.641 212.792,-171.641 212.792,-112.719 136.587,-112.719\"/>\n",
       "<text text-anchor=\"middle\" x=\"174.689\" y=\"-159.641\" font-family=\"Menlo\" font-size=\"10.00\">grad=&#45;5.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" stroke-width=\"2\" points=\"136.587,-152 212.792,-152 \"/>\n",
       "<text text-anchor=\"middle\" x=\"174.689\" y=\"-140\" font-family=\"Menlo\" font-size=\"10.00\">value=0.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" stroke-width=\"2\" points=\"136.587,-132.359 212.792,-132.359 \"/>\n",
       "<text text-anchor=\"middle\" x=\"174.689\" y=\"-120.359\" font-family=\"Menlo\" font-size=\"10.00\">b</text>\n",
       "</g>\n",
       "<!-- 140455917316896backward&#45;&gt;140455917316656backward -->\n",
       "<g id=\"edge2\" class=\"edge\"><title>140455917316896backward&#45;&gt;140455917316656backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" d=\"M237.2,-224.756C240.543,-218.8 243.874,-212.492 246.689,-206.438 250.372,-198.517 253.866,-189.829 256.966,-181.496\"/>\n",
       "<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"260.31,-182.539 260.412,-171.945 253.726,-180.163 260.31,-182.539\"/>\n",
       "<text text-anchor=\"middle\" x=\"274.761\" y=\"-195.238\" font-family=\"Menlo\" font-size=\"14.00\">&#45;1.00</text>\n",
       "</g>\n",
       "<!-- 140455917316896backward&#45;&gt;140455917316848backward -->\n",
       "<g id=\"edge5\" class=\"edge\"><title>140455917316896backward&#45;&gt;140455917316848backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" d=\"M207.983,-224.727C202.567,-211.46 196.056,-195.514 190.235,-181.255\"/>\n",
       "<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"193.357,-179.642 186.336,-171.707 186.876,-182.288 193.357,-179.642\"/>\n",
       "<text text-anchor=\"middle\" x=\"221.761\" y=\"-195.238\" font-family=\"Menlo\" font-size=\"14.00\">&#45;1.00</text>\n",
       "</g>\n",
       "<!-- 140455917317040backward -->\n",
       "<g id=\"node11\" class=\"node\"><title>140455917317040backward</title>\n",
       "<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M52.5867,-112.719C52.5867,-112.719 104.792,-112.719 104.792,-112.719 110.792,-112.719 116.792,-118.719 116.792,-124.719 116.792,-124.719 116.792,-159.641 116.792,-159.641 116.792,-165.641 110.792,-171.641 104.792,-171.641 104.792,-171.641 52.5867,-171.641 52.5867,-171.641 46.5867,-171.641 40.5867,-165.641 40.5867,-159.641 40.5867,-159.641 40.5867,-124.719 40.5867,-124.719 40.5867,-118.719 46.5867,-112.719 52.5867,-112.719\"/>\n",
       "<text text-anchor=\"middle\" x=\"78.6892\" y=\"-159.641\" font-family=\"Menlo\" font-size=\"10.00\">grad=&#45;4.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"40.5867,-152 116.792,-152 \"/>\n",
       "<text text-anchor=\"middle\" x=\"78.6892\" y=\"-140\" font-family=\"Menlo\" font-size=\"10.00\">value=0.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"40.5867,-132.359 116.792,-132.359 \"/>\n",
       "<text text-anchor=\"middle\" x=\"78.6892\" y=\"-120.359\" font-family=\"Menlo\" font-size=\"10.00\">*</text>\n",
       "</g>\n",
       "<!-- 140455917317040backward&#45;&gt;140455917316704backward -->\n",
       "<g id=\"edge10\" class=\"edge\"><title>140455917317040backward&#45;&gt;140455917316704backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" d=\"M103.661,-112.509C115.661,-98.7314 130.178,-82.065 142.952,-67.3991\"/>\n",
       "<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"145.913,-69.3279 149.842,-59.4883 140.635,-64.7303 145.913,-69.3279\"/>\n",
       "<text text-anchor=\"middle\" x=\"153.761\" y=\"-83.0187\" font-family=\"Menlo\" font-size=\"14.00\">&#45;8.00</text>\n",
       "</g>\n",
       "<!-- 140455917316560backward -->\n",
       "<g id=\"node12\" class=\"node\"><title>140455917316560backward</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"78.6892\" cy=\"-29.9609\" rx=\"34.8795\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"78.6892\" y=\"-27.7812\" font-family=\"Menlo\" font-size=\"10.00\">x2=2.00</text>\n",
       "</g>\n",
       "<!-- 140455917317040backward&#45;&gt;140455917316560backward -->\n",
       "<g id=\"edge12\" class=\"edge\"><title>140455917317040backward&#45;&gt;140455917316560backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" stroke-dasharray=\"5,2\" d=\"M78.6892,-112.509C78.6892,-92.0979 78.6892,-65.346 78.6892,-48.1524\"/>\n",
       "</g>\n",
       "<!-- 140455917317088backward&#45;&gt;140455917316848backward -->\n",
       "<g id=\"edge8\" class=\"edge\"><title>140455917317088backward&#45;&gt;140455917316848backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" d=\"M134.971,-224.667C138.796,-213.705 143.512,-201.193 148.546,-190.141 149.919,-187.126 151.414,-184.044 152.97,-180.972\"/>\n",
       "<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"156.205,-182.341 157.765,-171.862 150.01,-179.081 156.205,-182.341\"/>\n",
       "<text text-anchor=\"middle\" x=\"169.761\" y=\"-195.238\" font-family=\"Menlo\" font-size=\"14.00\">&#45;4.00</text>\n",
       "</g>\n",
       "<!-- 140455917317088backward&#45;&gt;140455917317040backward -->\n",
       "<g id=\"edge11\" class=\"edge\"><title>140455917317088backward&#45;&gt;140455917317040backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" d=\"M88.513,-224.687C83.6475,-219.187 79.3758,-213.056 76.5457,-206.438 73.2901,-198.823 72.1369,-190.21 72.1306,-181.847\"/>\n",
       "<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"75.6265,-182.013 72.6273,-171.852 68.6351,-181.666 75.6265,-182.013\"/>\n",
       "<text text-anchor=\"middle\" x=\"97.761\" y=\"-195.238\" font-family=\"Menlo\" font-size=\"14.00\">&#45;4.00</text>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<graphviz.graphs.Digraph at 0x7fbe70f963d0>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Training data: two (x, y) samples. Inputs and targets are fixed data,\n",
    "# so they are created with requires_grad=False.\n",
    "x1 = Scalar(1.5, label='x1', requires_grad=False)\n",
    "y1 = Scalar(1.0, label='y1', requires_grad=False)\n",
    "x2 = Scalar(2.0, label='x2', requires_grad=False)\n",
    "y2 = Scalar(4.0, label='y2', requires_grad=False)\n",
    "\n",
    "# Backpropagation: build a fresh linear model, combine the per-sample\n",
    "# errors into an MSE loss, and propagate gradients back from the loss.\n",
    "model = Linear()\n",
    "sample_errors = [model.error(x1, y1), model.error(x2, y2)]\n",
    "loss = mse(sample_errors)\n",
    "loss.backward()\n",
    "\n",
    "# Render the computation graph of the loss in 'backward' mode.\n",
    "draw_graph(loss, 'backward')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 2.38.0 (20140413.2041)\n",
       " -->\n",
       "<!-- Title: %3 Pages: 1 -->\n",
       "<svg width=\"288pt\" height=\"629pt\"\n",
       " viewBox=\"0.00 0.00 287.57 629.02\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 625.016)\">\n",
       "<title>%3</title>\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-625.016 283.568,-625.016 283.568,4 -4,4\"/>\n",
       "<!-- 140455917315120backward -->\n",
       "<g id=\"node1\" class=\"node\"><title>140455917315120backward</title>\n",
       "<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M44,-224.938C44,-224.938 96.2051,-224.938 96.2051,-224.938 102.205,-224.938 108.205,-230.938 108.205,-236.938 108.205,-236.938 108.205,-271.859 108.205,-271.859 108.205,-277.859 102.205,-283.859 96.2051,-283.859 96.2051,-283.859 44,-283.859 44,-283.859 38,-283.859 32,-277.859 32,-271.859 32,-271.859 32,-236.938 32,-236.938 32,-230.938 38,-224.938 44,-224.938\"/>\n",
       "<text text-anchor=\"middle\" x=\"70.1025\" y=\"-271.859\" font-family=\"Menlo\" font-size=\"10.00\">grad=&#45;1.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"32,-264.219 108.205,-264.219 \"/>\n",
       "<text text-anchor=\"middle\" x=\"70.1025\" y=\"-252.219\" font-family=\"Menlo\" font-size=\"10.00\">value=0.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"32,-244.578 108.205,-244.578 \"/>\n",
       "<text text-anchor=\"middle\" x=\"70.1025\" y=\"-232.578\" font-family=\"Menlo\" font-size=\"10.00\">+</text>\n",
       "</g>\n",
       "<!-- 140455917315888backward -->\n",
       "<g id=\"node5\" class=\"node\"><title>140455917315888backward</title>\n",
       "<polygon fill=\"lightgreen\" stroke=\"black\" stroke-width=\"2\" points=\"0,-112.719 0,-171.641 76.2051,-171.641 76.2051,-112.719 0,-112.719\"/>\n",
       "<text text-anchor=\"middle\" x=\"38.1025\" y=\"-159.641\" font-family=\"Menlo\" font-size=\"10.00\">grad=&#45;1.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" stroke-width=\"2\" points=\"0,-152 76.2051,-152 \"/>\n",
       "<text text-anchor=\"middle\" x=\"38.1025\" y=\"-140\" font-family=\"Menlo\" font-size=\"10.00\">value=0.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" stroke-width=\"2\" points=\"0,-132.359 76.2051,-132.359 \"/>\n",
       "<text text-anchor=\"middle\" x=\"38.1025\" y=\"-120.359\" font-family=\"Menlo\" font-size=\"10.00\">b</text>\n",
       "</g>\n",
       "<!-- 140455917315120backward&#45;&gt;140455917315888backward -->\n",
       "<g id=\"edge7\" class=\"edge\"><title>140455917315120backward&#45;&gt;140455917315888backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" d=\"M59.0371,-224.771C56.8868,-218.759 54.748,-212.425 52.959,-206.438 50.5711,-198.446 48.3232,-189.794 46.3338,-181.529\"/>\n",
       "<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"49.7237,-180.654 44.0431,-171.711 42.9068,-182.245 49.7237,-180.654\"/>\n",
       "<text text-anchor=\"middle\" x=\"74.1743\" y=\"-195.238\" font-family=\"Menlo\" font-size=\"14.00\">&#45;1.00</text>\n",
       "</g>\n",
       "<!-- 140455917315024backward -->\n",
       "<g id=\"node10\" class=\"node\"><title>140455917315024backward</title>\n",
       "<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M106,-112.719C106,-112.719 158.205,-112.719 158.205,-112.719 164.205,-112.719 170.205,-118.719 170.205,-124.719 170.205,-124.719 170.205,-159.641 170.205,-159.641 170.205,-165.641 164.205,-171.641 158.205,-171.641 158.205,-171.641 106,-171.641 106,-171.641 100,-171.641 94,-165.641 94,-159.641 94,-159.641 94,-124.719 94,-124.719 94,-118.719 100,-112.719 106,-112.719\"/>\n",
       "<text text-anchor=\"middle\" x=\"132.103\" y=\"-159.641\" font-family=\"Menlo\" font-size=\"10.00\">grad=&#45;1.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"94,-152 170.205,-152 \"/>\n",
       "<text text-anchor=\"middle\" x=\"132.103\" y=\"-140\" font-family=\"Menlo\" font-size=\"10.00\">value=0.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"94,-132.359 170.205,-132.359 \"/>\n",
       "<text text-anchor=\"middle\" x=\"132.103\" y=\"-120.359\" font-family=\"Menlo\" font-size=\"10.00\">*</text>\n",
       "</g>\n",
       "<!-- 140455917315120backward&#45;&gt;140455917315024backward -->\n",
       "<g id=\"edge9\" class=\"edge\"><title>140455917315120backward&#45;&gt;140455917315024backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" d=\"M86.2305,-224.727C93.8369,-211.205 103.009,-194.9 111.146,-180.435\"/>\n",
       "<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"114.203,-182.139 116.055,-171.707 108.102,-178.707 114.203,-182.139\"/>\n",
       "<text text-anchor=\"middle\" x=\"126.174\" y=\"-195.238\" font-family=\"Menlo\" font-size=\"14.00\">&#45;1.00</text>\n",
       "</g>\n",
       "<!-- 140455917314112backward -->\n",
       "<g id=\"node2\" class=\"node\"><title>140455917314112backward</title>\n",
       "<polygon fill=\"lightgreen\" stroke=\"black\" stroke-width=\"2\" points=\"49,-0.5 49,-59.4219 125.205,-59.4219 125.205,-0.5 49,-0.5\"/>\n",
       "<text text-anchor=\"middle\" x=\"87.1025\" y=\"-47.4219\" font-family=\"Menlo\" font-size=\"10.00\">grad=&#45;1.50</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" stroke-width=\"2\" points=\"49,-39.7812 125.205,-39.7812 \"/>\n",
       "<text text-anchor=\"middle\" x=\"87.1025\" y=\"-27.7812\" font-family=\"Menlo\" font-size=\"10.00\">value=0.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" stroke-width=\"2\" points=\"49,-20.1406 125.205,-20.1406 \"/>\n",
       "<text text-anchor=\"middle\" x=\"87.1025\" y=\"-8.14062\" font-family=\"Menlo\" font-size=\"10.00\">a</text>\n",
       "</g>\n",
       "<!-- 140455917316224backward -->\n",
       "<g id=\"node3\" class=\"node\"><title>140455917316224backward</title>\n",
       "<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M170,-561.594C170,-561.594 222.205,-561.594 222.205,-561.594 228.205,-561.594 234.205,-567.594 234.205,-573.594 234.205,-573.594 234.205,-608.516 234.205,-608.516 234.205,-614.516 228.205,-620.516 222.205,-620.516 222.205,-620.516 170,-620.516 170,-620.516 164,-620.516 158,-614.516 158,-608.516 158,-608.516 158,-573.594 158,-573.594 158,-567.594 164,-561.594 170,-561.594\"/>\n",
       "<text text-anchor=\"middle\" x=\"196.103\" y=\"-608.516\" font-family=\"Menlo\" font-size=\"10.00\">grad=1.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"158,-600.875 234.205,-600.875 \"/>\n",
       "<text text-anchor=\"middle\" x=\"196.103\" y=\"-588.875\" font-family=\"Menlo\" font-size=\"10.00\">value=0.50</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"158,-581.234 234.205,-581.234 \"/>\n",
       "<text text-anchor=\"middle\" x=\"196.103\" y=\"-569.234\" font-family=\"Menlo\" font-size=\"10.00\">*</text>\n",
       "</g>\n",
       "<!-- 140455917315936backward -->\n",
       "<g id=\"node7\" class=\"node\"><title>140455917315936backward</title>\n",
       "<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M106,-449.375C106,-449.375 158.205,-449.375 158.205,-449.375 164.205,-449.375 170.205,-455.375 170.205,-461.375 170.205,-461.375 170.205,-496.297 170.205,-496.297 170.205,-502.297 164.205,-508.297 158.205,-508.297 158.205,-508.297 106,-508.297 106,-508.297 100,-508.297 94,-502.297 94,-496.297 94,-496.297 94,-461.375 94,-461.375 94,-455.375 100,-449.375 106,-449.375\"/>\n",
       "<text text-anchor=\"middle\" x=\"132.103\" y=\"-496.297\" font-family=\"Menlo\" font-size=\"10.00\">grad=0.50</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"94,-488.656 170.205,-488.656 \"/>\n",
       "<text text-anchor=\"middle\" x=\"132.103\" y=\"-476.656\" font-family=\"Menlo\" font-size=\"10.00\">value=1.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"94,-469.016 170.205,-469.016 \"/>\n",
       "<text text-anchor=\"middle\" x=\"132.103\" y=\"-457.016\" font-family=\"Menlo\" font-size=\"10.00\">mse</text>\n",
       "</g>\n",
       "<!-- 140455917316224backward&#45;&gt;140455917315936backward -->\n",
       "<g id=\"edge5\" class=\"edge\"><title>140455917316224backward&#45;&gt;140455917315936backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" d=\"M179.454,-561.384C171.603,-547.862 162.135,-531.556 153.735,-517.091\"/>\n",
       "<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"156.716,-515.254 148.667,-508.363 150.662,-518.769 156.716,-515.254\"/>\n",
       "<text text-anchor=\"middle\" x=\"184.96\" y=\"-531.894\" font-family=\"Menlo\" font-size=\"14.00\">0.50</text>\n",
       "</g>\n",
       "<!-- 140455917316032backward -->\n",
       "<g id=\"node9\" class=\"node\"><title>140455917316032backward</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"234.103\" cy=\"-478.836\" rx=\"45.43\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"234.103\" y=\"-476.656\" font-family=\"Menlo\" font-size=\"10.00\">input=0.50</text>\n",
       "</g>\n",
       "<!-- 140455917316224backward&#45;&gt;140455917316032backward -->\n",
       "<g id=\"edge6\" class=\"edge\"><title>140455917316224backward&#45;&gt;140455917316032backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" stroke-dasharray=\"5,2\" d=\"M205.987,-561.384C213.068,-540.845 222.363,-513.886 228.286,-496.706\"/>\n",
       "</g>\n",
       "<!-- 140455917315216backward -->\n",
       "<g id=\"node4\" class=\"node\"><title>140455917315216backward</title>\n",
       "<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M106,-337.156C106,-337.156 158.205,-337.156 158.205,-337.156 164.205,-337.156 170.205,-343.156 170.205,-349.156 170.205,-349.156 170.205,-384.078 170.205,-384.078 170.205,-390.078 164.205,-396.078 158.205,-396.078 158.205,-396.078 106,-396.078 106,-396.078 100,-396.078 94,-390.078 94,-384.078 94,-384.078 94,-349.156 94,-349.156 94,-343.156 100,-337.156 106,-337.156\"/>\n",
       "<text text-anchor=\"middle\" x=\"132.103\" y=\"-384.078\" font-family=\"Menlo\" font-size=\"10.00\">grad=1.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"94,-376.438 170.205,-376.438 \"/>\n",
       "<text text-anchor=\"middle\" x=\"132.103\" y=\"-364.438\" font-family=\"Menlo\" font-size=\"10.00\">value=1.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"94,-356.797 170.205,-356.797 \"/>\n",
       "<text text-anchor=\"middle\" x=\"132.103\" y=\"-344.797\" font-family=\"Menlo\" font-size=\"10.00\">&#45;</text>\n",
       "</g>\n",
       "<!-- 140455917315216backward&#45;&gt;140455917315120backward -->\n",
       "<g id=\"edge1\" class=\"edge\"><title>140455917315216backward&#45;&gt;140455917315120backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" d=\"M110.721,-337.115C106.608,-331.149 102.487,-324.8 98.959,-318.656 94.3731,-310.671 89.9578,-301.827 86.0186,-293.345\"/>\n",
       "<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"89.0978,-291.66 81.7898,-283.989 82.719,-294.543 89.0978,-291.66\"/>\n",
       "<text text-anchor=\"middle\" x=\"120.174\" y=\"-307.456\" font-family=\"Menlo\" font-size=\"14.00\">&#45;1.00</text>\n",
       "</g>\n",
       "<!-- 140455917316416backward -->\n",
       "<g id=\"node6\" class=\"node\"><title>140455917316416backward</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"161.103\" cy=\"-254.398\" rx=\"34.8795\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"161.103\" y=\"-252.219\" font-family=\"Menlo\" font-size=\"10.00\">y1=1.00</text>\n",
       "</g>\n",
       "<!-- 140455917315216backward&#45;&gt;140455917316416backward -->\n",
       "<g id=\"edge3\" class=\"edge\"><title>140455917315216backward&#45;&gt;140455917316416backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" stroke-dasharray=\"5,2\" d=\"M139.646,-336.946C145.05,-316.408 152.143,-289.449 156.664,-272.269\"/>\n",
       "</g>\n",
       "<!-- 140455917315936backward&#45;&gt;140455917315216backward -->\n",
       "<g id=\"edge2\" class=\"edge\"><title>140455917315936backward&#45;&gt;140455917315216backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" d=\"M132.103,-449.165C132.103,-436.153 132.103,-420.564 132.103,-406.517\"/>\n",
       "<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"135.603,-406.145 132.103,-396.145 128.603,-406.145 135.603,-406.145\"/>\n",
       "<text text-anchor=\"middle\" x=\"148.96\" y=\"-419.675\" font-family=\"Menlo\" font-size=\"14.00\">1.00</text>\n",
       "</g>\n",
       "<!-- 140455917316512backward -->\n",
       "<g id=\"node8\" class=\"node\"><title>140455917316512backward</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"178.103\" cy=\"-29.9609\" rx=\"34.8795\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"178.103\" y=\"-27.7812\" font-family=\"Menlo\" font-size=\"10.00\">x1=1.50</text>\n",
       "</g>\n",
       "<!-- 140455917315024backward&#45;&gt;140455917314112backward -->\n",
       "<g id=\"edge4\" class=\"edge\"><title>140455917315024backward&#45;&gt;140455917314112backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" d=\"M120.397,-112.509C114.98,-99.2417 108.47,-83.2956 102.648,-69.0367\"/>\n",
       "<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"105.77,-67.4235 98.7497,-59.4883 99.2893,-70.0694 105.77,-67.4235\"/>\n",
       "<text text-anchor=\"middle\" x=\"134.174\" y=\"-83.0187\" font-family=\"Menlo\" font-size=\"14.00\">&#45;1.50</text>\n",
       "</g>\n",
       "<!-- 140455917315024backward&#45;&gt;140455917316512backward -->\n",
       "<g id=\"edge8\" class=\"edge\"><title>140455917315024backward&#45;&gt;140455917316512backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" stroke-dasharray=\"5,2\" d=\"M149.885,-112.659C153.207,-106.694 156.461,-100.351 159.103,-94.2188 165.636,-79.0538 170.831,-60.9343 174.126,-47.9823\"/>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<graphviz.graphs.Digraph at 0x7fbe710eeb50>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Gradient accumulation, pass 1 of 2: start from a fresh model\n",
    "model = Linear()\n",
    "# Back-propagate the single sample (x1, y1).\n",
    "# The loss is weighted by 0.5 because gradients will be accumulated\n",
    "# over 2 backward passes (this one and the (x2, y2) pass that follows).\n",
    "loss = 0.5 * mse([model.error(x1, y1)])\n",
    "loss.backward()\n",
    "draw_graph(loss, 'backward')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 2.38.0 (20140413.2041)\n",
       " -->\n",
       "<!-- Title: %3 Pages: 1 -->\n",
       "<svg width=\"365pt\" height=\"629pt\"\n",
       " viewBox=\"0.00 0.00 364.57 629.02\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 625.016)\">\n",
       "<title>%3</title>\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-625.016 360.568,-625.016 360.568,4 -4,4\"/>\n",
       "<!-- 140455917316608backward -->\n",
       "<g id=\"node1\" class=\"node\"><title>140455917316608backward</title>\n",
       "<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M176.99,-449.375C176.99,-449.375 235.215,-449.375 235.215,-449.375 241.215,-449.375 247.215,-455.375 247.215,-461.375 247.215,-461.375 247.215,-496.297 247.215,-496.297 247.215,-502.297 241.215,-508.297 235.215,-508.297 235.215,-508.297 176.99,-508.297 176.99,-508.297 170.99,-508.297 164.99,-502.297 164.99,-496.297 164.99,-496.297 164.99,-461.375 164.99,-461.375 164.99,-455.375 170.99,-449.375 176.99,-449.375\"/>\n",
       "<text text-anchor=\"middle\" x=\"206.103\" y=\"-496.297\" font-family=\"Menlo\" font-size=\"10.00\">grad=0.50</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"164.99,-488.656 247.215,-488.656 \"/>\n",
       "<text text-anchor=\"middle\" x=\"206.103\" y=\"-476.656\" font-family=\"Menlo\" font-size=\"10.00\">value=16.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"164.99,-469.016 247.215,-469.016 \"/>\n",
       "<text text-anchor=\"middle\" x=\"206.103\" y=\"-457.016\" font-family=\"Menlo\" font-size=\"10.00\">mse</text>\n",
       "</g>\n",
       "<!-- 140455917315552backward -->\n",
       "<g id=\"node10\" class=\"node\"><title>140455917315552backward</title>\n",
       "<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M180,-337.156C180,-337.156 232.205,-337.156 232.205,-337.156 238.205,-337.156 244.205,-343.156 244.205,-349.156 244.205,-349.156 244.205,-384.078 244.205,-384.078 244.205,-390.078 238.205,-396.078 232.205,-396.078 232.205,-396.078 180,-396.078 180,-396.078 174,-396.078 168,-390.078 168,-384.078 168,-384.078 168,-349.156 168,-349.156 168,-343.156 174,-337.156 180,-337.156\"/>\n",
       "<text text-anchor=\"middle\" x=\"206.103\" y=\"-384.078\" font-family=\"Menlo\" font-size=\"10.00\">grad=4.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"168,-376.438 244.205,-376.438 \"/>\n",
       "<text text-anchor=\"middle\" x=\"206.103\" y=\"-364.438\" font-family=\"Menlo\" font-size=\"10.00\">value=4.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"168,-356.797 244.205,-356.797 \"/>\n",
       "<text text-anchor=\"middle\" x=\"206.103\" y=\"-344.797\" font-family=\"Menlo\" font-size=\"10.00\">&#45;</text>\n",
       "</g>\n",
       "<!-- 140455917316608backward&#45;&gt;140455917315552backward -->\n",
       "<g id=\"edge9\" class=\"edge\"><title>140455917316608backward&#45;&gt;140455917315552backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" d=\"M206.103,-449.165C206.103,-436.153 206.103,-420.564 206.103,-406.517\"/>\n",
       "<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"209.603,-406.145 206.103,-396.145 202.603,-406.145 209.603,-406.145\"/>\n",
       "<text text-anchor=\"middle\" x=\"222.96\" y=\"-419.675\" font-family=\"Menlo\" font-size=\"14.00\">4.00</text>\n",
       "</g>\n",
       "<!-- 140455917315648backward -->\n",
       "<g id=\"node2\" class=\"node\"><title>140455917315648backward</title>\n",
       "<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M246,-561.594C246,-561.594 298.205,-561.594 298.205,-561.594 304.205,-561.594 310.205,-567.594 310.205,-573.594 310.205,-573.594 310.205,-608.516 310.205,-608.516 310.205,-614.516 304.205,-620.516 298.205,-620.516 298.205,-620.516 246,-620.516 246,-620.516 240,-620.516 234,-614.516 234,-608.516 234,-608.516 234,-573.594 234,-573.594 234,-567.594 240,-561.594 246,-561.594\"/>\n",
       "<text text-anchor=\"middle\" x=\"272.103\" y=\"-608.516\" font-family=\"Menlo\" font-size=\"10.00\">grad=1.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"234,-600.875 310.205,-600.875 \"/>\n",
       "<text text-anchor=\"middle\" x=\"272.103\" y=\"-588.875\" font-family=\"Menlo\" font-size=\"10.00\">value=8.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"234,-581.234 310.205,-581.234 \"/>\n",
       "<text text-anchor=\"middle\" x=\"272.103\" y=\"-569.234\" font-family=\"Menlo\" font-size=\"10.00\">*</text>\n",
       "</g>\n",
       "<!-- 140455917315648backward&#45;&gt;140455917316608backward -->\n",
       "<g id=\"edge6\" class=\"edge\"><title>140455917315648backward&#45;&gt;140455917316608backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" d=\"M254.934,-561.384C246.837,-547.862 237.073,-531.556 228.411,-517.091\"/>\n",
       "<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"231.325,-515.145 223.185,-508.363 225.32,-518.741 231.325,-515.145\"/>\n",
       "<text text-anchor=\"middle\" x=\"259.96\" y=\"-531.894\" font-family=\"Menlo\" font-size=\"14.00\">0.50</text>\n",
       "</g>\n",
       "<!-- 140455917317568backward -->\n",
       "<g id=\"node8\" class=\"node\"><title>140455917317568backward</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"311.103\" cy=\"-478.836\" rx=\"45.43\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"311.103\" y=\"-476.656\" font-family=\"Menlo\" font-size=\"10.00\">input=0.50</text>\n",
       "</g>\n",
       "<!-- 140455917315648backward&#45;&gt;140455917317568backward -->\n",
       "<g id=\"edge7\" class=\"edge\"><title>140455917315648backward&#45;&gt;140455917317568backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" stroke-dasharray=\"5,2\" d=\"M282.248,-561.384C289.515,-540.845 299.054,-513.886 305.133,-496.706\"/>\n",
       "</g>\n",
       "<!-- 140455917314112backward -->\n",
       "<g id=\"node3\" class=\"node\"><title>140455917314112backward</title>\n",
       "<polygon fill=\"lightgreen\" stroke=\"black\" stroke-width=\"2\" points=\"0,-0.5 0,-59.4219 76.2051,-59.4219 76.2051,-0.5 0,-0.5\"/>\n",
       "<text text-anchor=\"middle\" x=\"38.1025\" y=\"-47.4219\" font-family=\"Menlo\" font-size=\"10.00\">grad=&#45;9.50</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" stroke-width=\"2\" points=\"0,-39.7812 76.2051,-39.7812 \"/>\n",
       "<text text-anchor=\"middle\" x=\"38.1025\" y=\"-27.7812\" font-family=\"Menlo\" font-size=\"10.00\">value=0.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" stroke-width=\"2\" points=\"0,-20.1406 76.2051,-20.1406 \"/>\n",
       "<text text-anchor=\"middle\" x=\"38.1025\" y=\"-8.14062\" font-family=\"Menlo\" font-size=\"10.00\">a</text>\n",
       "</g>\n",
       "<!-- 140455917316800backward -->\n",
       "<g id=\"node4\" class=\"node\"><title>140455917316800backward</title>\n",
       "<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M118,-224.938C118,-224.938 170.205,-224.938 170.205,-224.938 176.205,-224.938 182.205,-230.938 182.205,-236.938 182.205,-236.938 182.205,-271.859 182.205,-271.859 182.205,-277.859 176.205,-283.859 170.205,-283.859 170.205,-283.859 118,-283.859 118,-283.859 112,-283.859 106,-277.859 106,-271.859 106,-271.859 106,-236.938 106,-236.938 106,-230.938 112,-224.938 118,-224.938\"/>\n",
       "<text text-anchor=\"middle\" x=\"144.103\" y=\"-271.859\" font-family=\"Menlo\" font-size=\"10.00\">grad=&#45;4.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"106,-264.219 182.205,-264.219 \"/>\n",
       "<text text-anchor=\"middle\" x=\"144.103\" y=\"-252.219\" font-family=\"Menlo\" font-size=\"10.00\">value=0.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"106,-244.578 182.205,-244.578 \"/>\n",
       "<text text-anchor=\"middle\" x=\"144.103\" y=\"-232.578\" font-family=\"Menlo\" font-size=\"10.00\">+</text>\n",
       "</g>\n",
       "<!-- 140455917316368backward -->\n",
       "<g id=\"node5\" class=\"node\"><title>140455917316368backward</title>\n",
       "<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M57,-112.719C57,-112.719 109.205,-112.719 109.205,-112.719 115.205,-112.719 121.205,-118.719 121.205,-124.719 121.205,-124.719 121.205,-159.641 121.205,-159.641 121.205,-165.641 115.205,-171.641 109.205,-171.641 109.205,-171.641 57,-171.641 57,-171.641 51,-171.641 45,-165.641 45,-159.641 45,-159.641 45,-124.719 45,-124.719 45,-118.719 51,-112.719 57,-112.719\"/>\n",
       "<text text-anchor=\"middle\" x=\"83.1025\" y=\"-159.641\" font-family=\"Menlo\" font-size=\"10.00\">grad=&#45;4.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"45,-152 121.205,-152 \"/>\n",
       "<text text-anchor=\"middle\" x=\"83.1025\" y=\"-140\" font-family=\"Menlo\" font-size=\"10.00\">value=0.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"45,-132.359 121.205,-132.359 \"/>\n",
       "<text text-anchor=\"middle\" x=\"83.1025\" y=\"-120.359\" font-family=\"Menlo\" font-size=\"10.00\">*</text>\n",
       "</g>\n",
       "<!-- 140455917316800backward&#45;&gt;140455917316368backward -->\n",
       "<g id=\"edge4\" class=\"edge\"><title>140455917316800backward&#45;&gt;140455917316368backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" d=\"M124.16,-224.75C120.284,-218.795 116.37,-212.488 112.959,-206.438 108.361,-198.28 103.824,-189.298 99.7284,-180.727\"/>\n",
       "<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"102.887,-179.219 95.4758,-171.648 96.5479,-182.188 102.887,-179.219\"/>\n",
       "<text text-anchor=\"middle\" x=\"134.174\" y=\"-195.238\" font-family=\"Menlo\" font-size=\"14.00\">&#45;4.00</text>\n",
       "</g>\n",
       "<!-- 140455917315888backward -->\n",
       "<g id=\"node6\" class=\"node\"><title>140455917315888backward</title>\n",
       "<polygon fill=\"lightgreen\" stroke=\"black\" stroke-width=\"2\" points=\"139,-112.719 139,-171.641 215.205,-171.641 215.205,-112.719 139,-112.719\"/>\n",
       "<text text-anchor=\"middle\" x=\"177.103\" y=\"-159.641\" font-family=\"Menlo\" font-size=\"10.00\">grad=&#45;5.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" stroke-width=\"2\" points=\"139,-152 215.205,-152 \"/>\n",
       "<text text-anchor=\"middle\" x=\"177.103\" y=\"-140\" font-family=\"Menlo\" font-size=\"10.00\">value=0.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" stroke-width=\"2\" points=\"139,-132.359 215.205,-132.359 \"/>\n",
       "<text text-anchor=\"middle\" x=\"177.103\" y=\"-120.359\" font-family=\"Menlo\" font-size=\"10.00\">b</text>\n",
       "</g>\n",
       "<!-- 140455917316800backward&#45;&gt;140455917315888backward -->\n",
       "<g id=\"edge1\" class=\"edge\"><title>140455917316800backward&#45;&gt;140455917315888backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" d=\"M152.687,-224.727C156.621,-211.588 161.342,-195.821 165.579,-181.667\"/>\n",
       "<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"169.046,-182.291 168.561,-171.707 162.34,-180.283 169.046,-182.291\"/>\n",
       "<text text-anchor=\"middle\" x=\"184.174\" y=\"-195.238\" font-family=\"Menlo\" font-size=\"14.00\">&#45;4.00</text>\n",
       "</g>\n",
       "<!-- 140455917316368backward&#45;&gt;140455917314112backward -->\n",
       "<g id=\"edge5\" class=\"edge\"><title>140455917316368backward&#45;&gt;140455917314112backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" d=\"M71.3968,-112.509C65.9801,-99.2417 59.4697,-83.2956 53.6481,-69.0367\"/>\n",
       "<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"56.77,-67.4235 49.7497,-59.4883 50.2893,-70.0694 56.77,-67.4235\"/>\n",
       "<text text-anchor=\"middle\" x=\"85.1743\" y=\"-83.0187\" font-family=\"Menlo\" font-size=\"14.00\">&#45;8.00</text>\n",
       "</g>\n",
       "<!-- 140455917316560backward -->\n",
       "<g id=\"node9\" class=\"node\"><title>140455917316560backward</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"129.103\" cy=\"-29.9609\" rx=\"34.8795\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"129.103\" y=\"-27.7812\" font-family=\"Menlo\" font-size=\"10.00\">x2=2.00</text>\n",
       "</g>\n",
       "<!-- 140455917316368backward&#45;&gt;140455917316560backward -->\n",
       "<g id=\"edge8\" class=\"edge\"><title>140455917316368backward&#45;&gt;140455917316560backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" stroke-dasharray=\"5,2\" d=\"M100.885,-112.659C104.207,-106.694 107.461,-100.351 110.103,-94.2188 116.636,-79.0538 121.831,-60.9343 125.126,-47.9823\"/>\n",
       "</g>\n",
       "<!-- 140455917316464backward -->\n",
       "<g id=\"node7\" class=\"node\"><title>140455917316464backward</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"235.103\" cy=\"-254.398\" rx=\"34.8795\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"235.103\" y=\"-252.219\" font-family=\"Menlo\" font-size=\"10.00\">y2=4.00</text>\n",
       "</g>\n",
       "<!-- 140455917315552backward&#45;&gt;140455917316800backward -->\n",
       "<g id=\"edge2\" class=\"edge\"><title>140455917315552backward&#45;&gt;140455917316800backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" d=\"M184.721,-337.115C180.608,-331.149 176.487,-324.8 172.959,-318.656 168.373,-310.671 163.958,-301.827 160.019,-293.345\"/>\n",
       "<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"163.098,-291.66 155.79,-283.989 156.719,-294.543 163.098,-291.66\"/>\n",
       "<text text-anchor=\"middle\" x=\"194.174\" y=\"-307.456\" font-family=\"Menlo\" font-size=\"14.00\">&#45;4.00</text>\n",
       "</g>\n",
       "<!-- 140455917315552backward&#45;&gt;140455917316464backward -->\n",
       "<g id=\"edge3\" class=\"edge\"><title>140455917315552backward&#45;&gt;140455917316464backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" stroke-dasharray=\"5,2\" d=\"M213.646,-336.946C219.05,-316.408 226.143,-289.449 230.664,-272.269\"/>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<graphviz.graphs.Digraph at 0x7fbe70f961f0>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Gradient accumulation, pass 2 of 2: back-propagate (x2, y2) into the same model.\n",
    "# Gradients are not reset, so they add to those left by the (x1, y1) pass;\n",
    "# the 0.5 weight again accounts for the 2 accumulated passes.\n",
    "loss = 0.5 * mse([model.error(x2, y2)])\n",
    "loss.backward()\n",
    "draw_graph(loss, 'backward')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Fix the random seed so every run produces the same data\n",
    "torch.manual_seed(1024)\n",
    "# Raw feature: 200 evenly spaced points from 100 to 300\n",
    "x_origin = torch.linspace(100, 300, 200)\n",
    "# Standardize x (subtract mean, divide by std); unscaled inputs make gradient descent unstable\n",
    "x = (x_origin - x_origin.mean()) / x_origin.std()\n",
    "# Targets: the line y = 10x + 5 plus standard normal noise\n",
    "epsilon = torch.randn(x.shape)\n",
    "y = 10 * x + 5 + epsilon"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Step 4, Result: y = 3.12 * x + -1.99\n",
      "Step 8, Result: y = 3.48 * x + -2.28\n",
      "Step 12, Result: y = 3.22 * x + -1.97\n",
      "Step 16, Result: y = 2.85 * x + -1.22\n",
      "Step 20, Result: y = 2.68 * x + -0.23\n",
      "Step 24, Result: y = 2.92 * x + 1.08\n",
      "Step 28, Result: y = 3.74 * x + 2.61\n",
      "Step 32, Result: y = 5.07 * x + 4.15\n",
      "Step 36, Result: y = 6.73 * x + 5.52\n",
      "Step 40, Result: y = 8.22 * x + 6.48\n",
      "Step 44, Result: y = 9.36 * x + 5.75\n",
      "Step 48, Result: y = 9.75 * x + 5.42\n",
      "Step 52, Result: y = 9.88 * x + 5.28\n",
      "Step 56, Result: y = 9.89 * x + 5.26\n",
      "Step 60, Result: y = 9.89 * x + 5.20\n",
      "Step 64, Result: y = 9.88 * x + 5.18\n",
      "Step 68, Result: y = 9.88 * x + 5.17\n",
      "Step 72, Result: y = 9.84 * x + 5.14\n",
      "Step 76, Result: y = 9.86 * x + 5.15\n",
      "Step 80, Result: y = 9.94 * x + 5.21\n"
     ]
    }
   ],
   "source": [
    "# Build a fresh model\n",
    "model = Linear()\n",
    "# Number of samples used per parameter update (the effective batch size)\n",
    "batch_size = 20\n",
    "# Number of micro-batches whose gradients are accumulated per parameter update\n",
    "gradient_accumulation_iter = 4\n",
    "# The batch must split evenly into micro-batches; otherwise integer truncation would silently shrink it\n",
    "assert batch_size % gradient_accumulation_iter == 0, 'batch_size must be divisible by gradient_accumulation_iter'\n",
    "# Number of samples per backward pass\n",
    "micro_size = batch_size // gradient_accumulation_iter\n",
    "learning_rate = 0.1\n",
    "# Number of parameter updates to perform\n",
    "num_updates = 20\n",
    "\n",
    "for t in range(num_updates * gradient_accumulation_iter):\n",
    "    # Select the current micro-batch. Indices wrap modulo len(x), so every\n",
    "    # micro-batch has exactly micro_size samples, even at the end of the data.\n",
    "    ix = (t * micro_size) % len(x)\n",
    "    idx = torch.arange(ix, ix + micro_size) % len(x)\n",
    "    xx = x[idx]\n",
    "    yy = y[idx]\n",
    "    # Loss on the current micro-batch\n",
    "    loss = mse([model.error(_x, _y) for _x, _y in zip(xx, yy)])\n",
    "    # Weight each micro-batch by 1 / gradient_accumulation_iter so that, assuming mse\n",
    "    # averages its inputs, the accumulated gradient equals that of the full batch\n",
    "    loss *= 1 / gradient_accumulation_iter\n",
    "    # Back-propagate; gradients keep accumulating until they are reset below\n",
    "    loss.backward()\n",
    "    if (t + 1) % gradient_accumulation_iter == 0:\n",
    "        # Gradient-descent update using the accumulated gradient\n",
    "        model.a -= learning_rate * model.a.grad\n",
    "        model.b -= learning_rate * model.b.grad\n",
    "        # Clear the consumed gradients before the next batch\n",
    "        model.a.grad = 0.0\n",
    "        model.b.grad = 0.0\n",
    "        print(f'Step {t + 1}, Result: {model.string()}')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
